- Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathELA_Training_Module_Final.py
122 lines (97 loc) · 4.62 KB
/
ELA_Training_Module_Final.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# `from pylab import *` dumps many numpy/pyplot names into this namespace
# (numpy.random's `random` function among them), so it must come first:
# every explicit import below deliberately overrides what it brings in.
from pylab import *

import io
import itertools
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.utils.np_utils import to_categorical
from PIL import Image, ImageChops, ImageEnhance
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
_ELA_QUALITY = 90  # JPEG quality used for the ELA re-encode.
_IMAGE_SIZE = (128, 128)  # (width, height) every ELA image is resized to.
_NUM_CLASSES = 2


def _convert_to_ela_image(path, quality):
    """Return the error-level-analysis (ELA) image of the picture at *path*.

    The picture is converted to RGB and re-encoded as JPEG at *quality*. The
    result is the per-pixel absolute difference between the two, brightened
    so that its largest channel value maps to 255.
    """
    with Image.open(path) as source:
        original = source.convert('RGB')

    # Re-encode in memory so no temporary file has to be derived from *path*
    # and written next to the input image.
    buffer = io.BytesIO()
    original.save(buffer, 'JPEG', quality=quality)
    buffer.seek(0)
    with Image.open(buffer) as resaved:
        ela_im = ImageChops.difference(original, resaved)

    # getextrema() yields one (min, max) pair per RGB band.
    max_diff = max([high for _, high in ela_im.getextrema()])
    if max_diff == 0:
        max_diff = 1  # identical images: avoid dividing by zero
    return ImageEnhance.Brightness(ela_im).enhance(255.0 / max_diff)


def _load_ela_dataset(csv_file):
    """Read *csv_file* and return ``(X, Y)`` ready for training.

    The first CSV column is read as an image path and the second as its class
    label. Labels are one-hot encoded into two classes, so they are expected
    to be 0 or 1. X has shape ``(n, height, width, 3)`` with values scaled to
    [0, 1]; Y has shape ``(n, 2)``.
    """
    dataset = pd.read_csv(csv_file)
    # Select columns by position so the CSV's header names do not matter.
    paths = dataset.iloc[:, 0]
    labels = dataset.iloc[:, 1]
    images = [
        np.asarray(_convert_to_ela_image(path, _ELA_QUALITY).resize(_IMAGE_SIZE)) / 255.0
        for path in paths
    ]
    width, height = _IMAGE_SIZE
    X = np.array(images).reshape(-1, height, width, 3)
    Y = to_categorical(labels.to_numpy(), _NUM_CLASSES)
    return X, Y


def _build_ela_model(lr):
    """Build, print a summary of, and compile the two-class ELA CNN (RMSprop at *lr*)."""
    width, height = _IMAGE_SIZE
    model = Sequential()
    model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='valid', activation='relu',
                     input_shape=(height, width, 3)))
    model.add(Conv2D(filters=32, kernel_size=(5, 5), strides=(2, 2), padding='valid', activation='relu'))
    model.add(MaxPool2D(pool_size=2, strides=None, padding='valid', data_format='channels_last'))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(256, activation="relu"))
    model.add(Dropout(0.50))
    model.add(Dense(_NUM_CLASSES, activation="softmax"))
    model.summary()

    try:
        # Keras >= 2.3 names the argument `learning_rate`; newer optimizer
        # implementations reject the legacy `lr` / `decay` arguments.
        optimizer = RMSprop(learning_rate=lr, rho=0.9, epsilon=1e-08)
    except TypeError:
        # Older Keras only understands `lr` (decay=0.0 was its default anyway).
        optimizer = RMSprop(lr=lr, rho=0.9, epsilon=1e-08, decay=0.0)
    model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"])
    return model


def _plot_training_curves(ax_loss, ax_acc, hist):
    """Plot training/validation loss on *ax_loss* and accuracy on *ax_acc*."""
    ax_loss.plot(hist['loss'], color='b', label="Training loss")
    ax_loss.plot(hist['val_loss'], color='r', label="validation loss")
    ax_loss.legend(loc='best', shadow=True)

    # Keras < 2.3 records accuracy as 'acc'; later versions use 'accuracy'.
    acc_key = 'accuracy' if 'accuracy' in hist else 'acc'
    ax_acc.plot(hist[acc_key], color='b', label="Training accuracy")
    ax_acc.plot(hist['val_' + acc_key], color='r', label="Validation accuracy")
    ax_acc.legend(loc='best', shadow=True)


def _plot_confusion_matrix(ax, cm_, classes, normalize=False, title_='Confusion matrix', cmap='Spectral'):
    """Draw confusion matrix *cm_* on *ax* with a colorbar and per-cell values.

    With ``normalize=True`` each row is divided by its sum (empty rows stay 0)
    before drawing, so the colours and the printed values agree.
    """
    if normalize:
        row_sums = cm_.sum(axis=1, keepdims=True)
        cm_ = cm_.astype(float) / np.where(row_sums == 0, 1, row_sums)

    image = ax.imshow(cm_, interpolation='nearest', cmap=cmap)
    ax.set_title(title_)
    ax.figure.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    tick_labels = [str(c) for c in classes]
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(tick_labels)

    # Light text on dark cells, dark text on light cells.
    thresh = cm_.max() / 2.
    for i, j in itertools.product(range(cm_.shape[0]), range(cm_.shape[1])):
        value = cm_[i, j]
        ax.text(j, i, f"{value:.2f}" if normalize else str(value),
                horizontalalignment="center",
                color="white" if value > thresh else "black")
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')


def _reserve_output_paths():
    """Return unused ``(plot_path, model_path)`` under the working directory.

    Both names share one random number so a figure can be matched to its
    model; numbers already taken by an existing file are skipped. The
    ``Figures`` and ``Re_Traind_Models`` directories are created if missing.
    """
    image_dir = os.path.join(os.getcwd(), "Figures")
    models_dir = os.path.join(os.getcwd(), "Re_Traind_Models")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)
    while True:
        suffix = str(random.randint(1, 1000000))
        plot_name = os.path.join(image_dir, "ELA_" + suffix + ".png")
        model_name = os.path.join(models_dir, "ELA_" + suffix + ".h5")
        if not (os.path.exists(plot_name) or os.path.exists(model_name)):
            return plot_name, model_name


def train_Ela_Model(csv_file, lr, ep):
    """Train the ELA image classifier and save its diagnostics figure and model.

    Args:
        csv_file: CSV whose first column holds image paths and whose second
            column holds the 0/1 class labels.
        lr: RMSprop learning rate.
        ep: Number of training epochs.

    Returns:
        ``(plot_path, model_path)``: the PNG holding the loss curves, the
        accuracy curves and the validation confusion matrix (saved under
        ``<cwd>/Figures``), and the saved ``.h5`` model (under
        ``<cwd>/Re_Traind_Models``).
    """
    X, Y = _load_ela_dataset(csv_file)
    X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.1, random_state=5, shuffle=True)

    model = _build_ela_model(lr)
    history = model.fit(X_train, Y_train, batch_size=5, epochs=ep,
                        validation_data=(X_val, Y_val), verbose=2)

    Y_pred_classes = np.argmax(model.predict(X_val), axis=1)
    Y_true = np.argmax(Y_val, axis=1)
    # Fix the label set so the matrix stays 2x2 even if a class is absent.
    confusion_mtx = confusion_matrix(Y_true, Y_pred_classes, labels=list(range(_NUM_CLASSES)))

    plot_name, model_name = _reserve_output_paths()
    fig, (ax_loss, ax_acc, ax_cm) = plt.subplots(3, 1)
    try:
        _plot_training_curves(ax_loss, ax_acc, history.history)
        _plot_confusion_matrix(ax_cm, confusion_mtx, classes=range(_NUM_CLASSES))
        fig.tight_layout()
        fig.savefig(plot_name, transparent=True, bbox_inches="tight", pad_inches=2, dpi=50)
    finally:
        # pyplot keeps every open figure alive until it is closed explicitly.
        plt.close(fig)

    model.save(model_name)
    return plot_name, model_name